#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import ipaddress
from base64 import b64decode, b64encode

import win32api
import win32security
from Crypto.Cipher import ARC4
from Crypto.Hash import MD5, SHA256

# Template for an Xshell session (.xsh) file, filled with str.format().
# Positional fields:
#   {0} host, {1} port, {2} description,
#   {3} authentication Method value, {4} user name, {5} user key,
#   {6} encrypted password,
#   {7} extra lines placed at the end of [CONNECTION:AUTHENTICATION]
#       (the expect/send script),
#   {8} session version written to [SessionInfo].
conf_template = "\n".join([
    "[CONNECTION]",
    "Protocol=SSH",
    "Host={0}",
    "Port={1}",
    "Description={2}",
    "[CONNECTION:AUTHENTICATION]",
    "Method={3}",
    "UserName={4}",
    "UserKey={5}",
    "Password={6}",
    "{7}",
    "[CONNECTION:SSH]",
    "ForwardX11=0",
    "[TERMINAL]",
    "Cols=80",
    "Rows=24",
    "ScrollbackSize=5120",
    "TerminalNameForEcho=Xshell",
    "[TERMINAL:WINDOW]",
    "FontSize=8",
    "AsianFontSize=8",
    "CursorBlink=1",
    "CursorBlinkInterval=360",
    "[SessionInfo]",
    "Version={8}",
    "Description=Xshell session file",
]) + "\n"


def get_user_info():
    """
    Return the current Windows user name and that account's string SID.

    The account is resolved with ``LookupAccountName`` against the local
    computer name, and both values are printed to stdout before returning.

    :return: tuple ``(username, sid)``; ``sid`` is the string form of the SID
    """
    username = win32api.GetUserName()
    computer_name = win32api.GetComputerName()
    # LookupAccountName returns (sid, domain, account_type); only the SID is needed.
    sid_object, *_ = win32security.LookupAccountName(computer_name, username)
    sid = win32security.ConvertSidToStringSid(sid_object)
    print("Username: %s, SID: %s" % (username, sid))
    return username, sid


class XShellCrypto(object):

    def __init__(self, version: float = 7.1, master_passwd: str = None):
        self._version = version
        match self._version:
            case v if 0 < v < 5.1:  # 将【!X@s#h$e%l^l&】的MD5值作为RC4的密钥key
                self._key = MD5.new(b'!X@s#h$e%l^l&').digest()
            case 5.1 | 5.2:  # 将用户的【SID】的SHA-256值作为RC4的密钥key
                self._key = SHA256.new(get_user_info()[1].encode()).digest()
            case v if 5.2 < v < 7.0:  # 将用户的【用户名+SID】的SHA-256值作为RC4的密钥key
                user = get_user_info()
                if master_passwd is None:
                    self._key = SHA256.new((user[0] + user[1]).encode()).digest()
                else:
                    self._key = SHA256.new(master_passwd.encode()).digest()
            case v if 7.0 <= v:  # 将用户的【SID倒序 + 用户名】的SHA-256值作为RC4的密钥key
                user = get_user_info()
                if master_passwd is None:
                    self._key = SHA256.new((user[1][::-1] + user[0]).encode()).digest()
                else:
                    self._key = SHA256.new(master_passwd.encode()).digest()
            case _:
                raise RuntimeError('版本无法识别，或不受支持')

    def encrypt_password(self, original_password: str):
        cipher = ARC4.new(self._key)
        if self._version < 5.1:
            return b64encode(cipher.encrypt(original_password.encode())).decode()
        else:
            checksum = SHA256.new(original_password.encode()).digest()
            ciphertext = cipher.encrypt(original_password.encode())
            return b64encode(ciphertext + checksum).decode()

    def decrypt_password(self, encrypted_password: str):
        cipher = ARC4.new(self._key)
        if self._version < 5.1:
            return cipher.decrypt(b64decode(encrypted_password)).decode()
        else:
            data = b64decode(encrypted_password)
            ciphertext, checksum = data[:-SHA256.digest_size], data[-SHA256.digest_size:]
            plaintext = cipher.decrypt(ciphertext)
            if SHA256.new(plaintext).digest() != checksum:
                raise ValueError('Cannot decrypt string. The key is wrong!')
            return plaintext.decode()

    def version(self):
        return self._version


class XFtpCrypto(object):
    def __init__(self, version: float = 7.1, master_passwd: str = None):
        self._version = version
        match self._version:
            case v if 0 < v < 5.1:  # 将【!X@s#c$e%l^l&】的MD5值作为RC4的密钥key（与Xshell的不一样）
                self._key = MD5.new(b'!X@s#c$e%l^l&').digest()
            case 5.1 | 5.2:  # 将用户的【SID】的SHA-256值作为RC4的密钥key
                self._key = SHA256.new(get_user_info()[1].encode()).digest()
            case v if 5.2 < v < 7.0:  # 将用户的【用户名+SID】的SHA-256值作为RC4的密钥key
                user = get_user_info()
                if master_passwd is None:
                    self._key = SHA256.new((user[0] + user[1]).encode()).digest()
                else:
                    self._key = SHA256.new(master_passwd.encode()).digest()
            case v if 7.0 <= v:  # 将用户的【SID倒序 + 用户名】的SHA-256值作为RC4的密钥key
                user = get_user_info()
                if master_passwd is None:
                    self._key = SHA256.new((user[1][::-1] + user[0]).encode()).digest()
                else:
                    self._key = SHA256.new(master_passwd.encode()).digest()
            case _:
                raise RuntimeError('版本无法识别，或不受支持')
        # self._cipher = ARC4.new(self._key)

    def encrypt_password(self, original_password: str):
        cipher = ARC4.new(self._key)
        if self._version < 5.1:
            return b64encode(cipher.encrypt(original_password.encode())).decode()
        else:
            checksum = SHA256.new(original_password.encode()).digest()
            ciphertext = cipher.encrypt(original_password.encode())
            return b64encode(ciphertext + checksum).decode()

    def decrypt_password(self, encrypted_password: str):
        cipher = ARC4.new(self._key)
        if self._version < 5.1:
            return cipher.decrypt(b64decode(encrypted_password)).decode()
        else:
            data = b64decode(encrypted_password)
            ciphertext, checksum = data[:-SHA256.digest_size], data[-SHA256.digest_size:]
            plaintext = cipher.decrypt(ciphertext)
            if SHA256.new(plaintext).digest() != checksum:
                raise ValueError('Cannot decrypt string. The key is wrong!')
            return plaintext.decode()

    def version(self):
        return self._version


def generate_xshell_conf(crypto: XShellCrypto, server_conf: dict, using_jumpserver: bool,
                         jump_server_conf: dict | None) -> str:
    """
    Build the content of an Xshell session (.xsh) file.

    :param crypto: crypto helper matching the target Xshell version; it encrypts
        the password and supplies the ``[SessionInfo] Version`` value
    :param server_conf: SSH settings of the target server, e.g.
        {'host': '10.150.25.20', 'port': '22', 'username': 'hybrisprod', 'password': '***',
         'role': 'Task', 'cluster_id': '0', 'app_dir': '/u01/prod/hybris5601/task7001-0'}
        An optional 'remark' becomes the session description; 'cd_app_dir' and
        'app_dir' are read by generate_expect_script.
    :param using_jumpserver: route through the jump server; only honoured when
        ``server_conf['host']`` is a private IP address
    :param jump_server_conf: None for a direct connection; otherwise the jump
        server account, e.g.
        {'host': '10.150.150.52', 'port': '2222', 'username': 'user', 'password': 'pw',
         'user_key': 'key name', 'method': 0, 'default_user': 'user'}
        'method' is written to ``Method=``; 'default_user' is the account the
        jump server lands on at the target (sudo switches away from it when it
        differs from ``server_conf['username']``).
    :return: the .xsh file content
    :raises ValueError: if a jump server is requested and ``server_conf['host']``
        is not an IP address
    """
    # Check the flag first so the host is only parsed as an IP when a jump server
    # is actually requested; hostnames then work for direct connections.
    use_jumpserver = bool(using_jumpserver) and is_private_ip(server_conf['host'])
    if use_jumpserver:
        host = jump_server_conf['host']
        port = str(jump_server_conf['port'])
        username = jump_server_conf['username']
        password = crypto.encrypt_password(jump_server_conf['password'])
        # A None key must not be rendered as the literal text "None".
        user_key = jump_server_conf['user_key'] or ''
        method = jump_server_conf['method']
        def_jp_user = jump_server_conf['default_user']
    else:
        host = server_conf['host']
        port = str(server_conf['port'])
        username = server_conf['username']
        password = crypto.encrypt_password(server_conf['password'])
        # Empty rather than None: str.format would otherwise write "UserKey=None".
        user_key = ''
        method = 0
        # Not read by the expect script when no jump server is used; this also
        # allows jump_server_conf to be None for direct connections.
        def_jp_user = None
    desc = server_conf.get('remark') or ''
    expect_script = generate_expect_script(use_jumpserver, server_conf, def_jp_user)
    return conf_template.format(host, port, desc, method, username, user_key, password,
                                expect_script, crypto.version())


def generate_expect_script(using_jumpserver: bool, server_conf: dict, def_jp_user: str) -> str:
    # UseExpectSend=1
    # ExpectSend_Count=1
    # ExpectSend_Expect_0=[hybrisprod@
    # ExpectSend_Send_0=source /etc/profile;ll
    # # ExpectSend_Hide_0=0  # 默认不隐藏文本，可省略
    except_count = 1
    scripts = []
    if using_jumpserver:
        except_count += 1
        scripts.append(('Opt>', server_conf['host']))
        if server_conf['username'] != def_jp_user:
            except_count += 1
            scripts.append((f'[{def_jp_user}@', f"sudo su - {server_conf['username']}"))
    if server_conf['cd_app_dir']:
        scripts.append((f"[{server_conf['username']}@", f"source /etc/profile;ll;cd {server_conf['app_dir']};pwd"))
    else:
        scripts.append((f"[{server_conf['username']}@", 'source /etc/profile;ll'))
    expect_scripts = ''
    for i in range(except_count):
        expect_scripts += f"ExpectSend_Expect_{i}={scripts[i][0]}\nExpectSend_Send_{i}={scripts[i][1]}\n"
    return f"UseExpectSend=1\nExpectSend_Count={except_count}\n" + expect_scripts


def is_private_ip(ip: str) -> bool:
    """
    Tell whether *ip* is an address that ``ipaddress`` classifies as private.

    :param ip: IPv4 or IPv6 address in text form
    :return: the ``is_private`` flag of the parsed address
    :raises ValueError: if *ip* is not a valid IP address (e.g. a hostname)
    """
    return ipaddress.ip_address(ip).is_private
